# Standard
import json
import os
import re
from dataclasses import dataclass
from typing import Union

# Third Party
from vllm.config import ParallelConfig
from vllm.logger import logger
from vllm.utils.network_utils import get_ip

from vllm_ascend.distributed.kvpool.backend.backend import Backend
from vllm_ascend.distributed.mooncake_transfer_engine import global_te

# Fallback sizes in bytes, used when the JSON config omits the matching key.
DEFAULT_GLOBAL_SEGMENT_SIZE = 3_355_443_200  # 3.125 GiB
DEFAULT_LOCAL_BUFFER_SIZE = 1 << 30  # 1.0 GiB


class MooncakeBackend(Backend):
    """KV-pool backend backed by a ``MooncakeDistributedStore``.

    The store settings are read from the JSON file named by the
    ``MOONCAKE_CONFIG_PATH`` environment variable.
    """

    def __init__(self, parallel_config: ParallelConfig):
        """Load the config and set up the Mooncake store.

        Args:
            parallel_config: Accepted for interface compatibility; it is not
                read by this constructor.

        Raises:
            ImportError: If the ``mooncake`` package cannot be imported.
            ValueError: If the configured protocol is not ``"ascend"`` (or if
                loading the config fails with a ``ValueError``).
            RuntimeError: If the store's ``setup`` call returns non-zero.
        """
        try:
            from mooncake.store import MooncakeDistributedStore  # type: ignore
        except ImportError as e:
            raise ImportError(
                "Please install mooncake by following the instructions at "
                "https://github.com/kvcache-ai/Mooncake/blob/main/doc/en/build.md "  # noqa: E501
                "to run vLLM with MooncakeConnector.") from e
        self.config = MooncakeStoreConfig.load_from_env()
        # Only the "ascend" setup path is implemented. Any other protocol used
        # to fall through to the return-code check with nothing assigned and
        # crash with UnboundLocalError; fail with an actionable message.
        if self.config.protocol != "ascend":
            msg = (f"Unsupported mooncake protocol "
                   f"'{self.config.protocol}': only 'ascend' is supported.")
            logger.error(msg)
            raise ValueError(msg)
        self.store = MooncakeDistributedStore()
        local_hostname = get_ip()
        transfer_engine = global_te.get_transfer_engine(local_hostname,
                                                        device_name=None)
        # The local segment is identified as "<ip>:<transfer engine rpc port>".
        self.local_seg = local_hostname + ":" + str(
            transfer_engine.get_rpc_port())
        ret = self.store.setup(self.local_seg, self.config.metadata_server,
                               self.config.global_segment_size,
                               self.config.local_buffer_size,
                               self.config.protocol,
                               self.config.device_name,
                               self.config.master_server_address,
                               transfer_engine.get_engine())
        if ret != 0:
            msg = "Initialize mooncake failed."
            logger.error(msg)
            raise RuntimeError(msg)

    def register_buffer(self, ptrs: list[int], lengths: list[int]):
        """Forward buffer registration to the global transfer engine.

        Args:
            ptrs: Buffer addresses (presumably one per region; confirm
                against ``global_te.register_buffer``).
            lengths: Buffer lengths, passed through unchanged.
        """
        global_te.register_buffer(ptrs, lengths)

    def exists(self, keys: list[str]) -> list[int]:
        """Return the store's ``batch_is_exist`` result for ``keys``."""
        return self.store.batch_is_exist(keys)

    def put(self, keys: list[str], addrs: list[list[int]],
            sizes: list[list[int]]):
        """Write a batch of keys from the given buffers into the store.

        Failures are best-effort: a negative per-key result or any exception
        is logged and not raised.
        """
        try:
            res = self.store.batch_put_from_multi_buffers(keys, addrs, sizes)
            # Log once per batch; the message already carries every key and
            # result, so repeating it per failed entry adds nothing.
            if any(value < 0 for value in res):
                logger.error(f"Failed to put key {keys},res:{res}")
        except Exception as e:
            logger.error(f"Failed to put key {keys},error:{e}")

    def get(self, keys: list[str], addrs: list[list[int]],
            sizes: list[list[int]]):
        """Read a batch of keys from the store into the given buffers.

        Failures are best-effort: a negative per-key result or any exception
        is logged and not raised.
        """
        try:
            res = self.store.batch_get_into_multi_buffers(keys, addrs, sizes)
            # Log once per batch; see put() for the rationale.
            if any(value < 0 for value in res):
                logger.error(f"Failed to get key {keys}, res:{res}")
        except Exception as e:
            logger.error(f"Failed to get key {keys}, error:{e}")


@dataclass
class MooncakeStoreConfig:
    """Settings for the Mooncake store, loaded from a JSON config file."""
    # Value of the "local_hostname" key; None when the key is absent.
    local_hostname: str
    # Value of the "metadata_server" key; None when the key is absent.
    metadata_server: str
    # Size in bytes once loaded via from_file (size strings are parsed there).
    global_segment_size: Union[int, str]
    # Taken from the file as-is; defaults to DEFAULT_LOCAL_BUFFER_SIZE.
    local_buffer_size: int
    # Defaults to "tcp" when the key is absent.
    protocol: str
    # Defaults to "" when the key is absent.
    device_name: str
    # Value of the "master_server_address" key; None when the key is absent.
    master_server_address: str

    @staticmethod
    def from_file(file_path: str) -> "MooncakeStoreConfig":
        """Build a config from the JSON object stored at ``file_path``.

        Missing keys fall back to the defaults noted on each field.

        Raises:
            OSError: If the file cannot be opened.
            json.JSONDecodeError: If the file is not valid JSON.
            ValueError, TypeError: If "global_segment_size" cannot be parsed.
        """
        # JSON text is UTF-8; do not let the platform locale pick the codec.
        with open(file_path, encoding="utf-8") as file:
            config = json.load(file)
        return MooncakeStoreConfig(
            local_hostname=config.get("local_hostname"),
            metadata_server=config.get("metadata_server"),
            global_segment_size=_parse_global_segment_size(
                config.get("global_segment_size",
                           DEFAULT_GLOBAL_SEGMENT_SIZE)),
            local_buffer_size=(config.get("local_buffer_size",
                                          DEFAULT_LOCAL_BUFFER_SIZE)),
            protocol=config.get("protocol", "tcp"),
            device_name=config.get("device_name", ""),
            master_server_address=config.get("master_server_address"))

    @staticmethod
    def load_from_env() -> "MooncakeStoreConfig":
        """Build a config from the file named by ``MOONCAKE_CONFIG_PATH``.

        Raises:
            ValueError: If the environment variable is unset or empty.
        """
        config_path = os.getenv("MOONCAKE_CONFIG_PATH")
        if not config_path:
            raise ValueError(
                "The environment variable 'MOONCAKE_CONFIG_PATH' is not set.")
        return MooncakeStoreConfig.from_file(config_path)


def _parse_global_segment_size(value) -> int:
    """
    Parse storage size strings with support for units: GB, MB, KB, B
    
    Args:
        value: Input value (int, str, or other convertible types)
        
    Returns:
        int: Size in bytes
        
    Raises:
        ValueError: For invalid format, missing number, or negative values
        TypeError: For unsupported input types
    """

    if isinstance(value, int):
        return value
    elif not isinstance(value, str):
        try:
            return int(value)
        except (TypeError, ValueError) as e:
            raise TypeError(
                f"Unsupported type for global_segment_size: {type(value)}"
            ) from e

    cleaned_input = value.strip().lower()
    if not cleaned_input:
        raise ValueError("global segment size cannot be empty.")

    UNIT_MULTIPLIERS = {
        'gb': 1024**3,  # 1 GB = 1024^3 bytes
        'mb': 1024**2,  # 1 MB = 1024^2 bytes
        'kb': 1024,  # 1 KB = 1024 bytes
        'b': 1  # 1 B = 1 byte
    }
    pattern = r'^\s*([\d.]+)\s*(gb|mb|kb|b)?\s*$'
    match = re.match(pattern, cleaned_input)

    if not match:
        raise ValueError(f"Invalid format: '{value}'")

    number_str = match.group(1)
    unit = match.group(2) or 'b'

    multiplier = UNIT_MULTIPLIERS[unit]
    return _convert_to_bytes(number_str, multiplier, value)


def _convert_to_bytes(number_str: str, multiplier: int,
                      original_input: str) -> int:
    """
    Convert numeric string to byte count
    
    Args:
        number_str: Numeric portion of input
        multiplier: Unit conversion factor
        original_input: Original input string (for error messages)
        
    Returns:
        int: Byte count
        
    Raises:
        ValueError: For invalid numbers or negative results
    """
    try:
        numeric_value = float(number_str)
    except ValueError:
        raise ValueError(
            f"Invalid numeric value '{number_str}' in: '{original_input}'")
    # Calculate byte count
    try:
        byte_count = int(numeric_value * multiplier)
    except OverflowError:
        raise ValueError(f"Storage size too large: '{original_input}'")
    return byte_count
